#!/usr/bin/env python

# ---- IMPORT MODULES

try:
    import matplotlib.pyplot as plt
    from matplotlib.font_manager import FontProperties
except ImportError:
    # Only a genuinely missing/unimportable package is reported as an install
    # problem; KeyboardInterrupt, SystemExit and unrelated errors propagate
    # unchanged instead of being relabelled.
    raise ImportError("Install 'matplotlib' to plot convergence results.")

# ---- CONVERGENCE PLOT


def ConvergencePlot(cost, save_path):
    """

    Plots the best and mean cost per iteration and saves the figure to disk.

    The figure is always closed before returning, even when plotting or
    saving fails, so repeated calls do not accumulate open figures.

    Parameters:
    ----------

        :param dict cost: cost history with the keys "best" and "mean", each a
                          sequence with one value per cycle/generation
                          (presumably as returned by an optimiser).
        :param save_path: destination handed unchanged to
                          ``matplotlib.pyplot.savefig``.

    Raises:
    ------

        KeyError: if ``cost`` lacks the "best" or "mean" entry.
    """

    # Read both series up front so a missing key fails before any figure
    # is created.
    best_history = cost["best"]
    mean_history = cost["mean"]

    font = FontProperties()
    font.set_size("larger")
    labels = ["Best", "Mean"]

    fig = plt.figure(figsize=(12, 4))
    try:
        plt.plot(range(len(best_history)), best_history, label=labels[0])
        plt.plot(
            range(len(mean_history)), mean_history, color="red", label=labels[1]
        )
        plt.xlabel("Iteration #", fontsize=16)
        plt.ylabel("Loss [-]", fontsize=16)
        plt.legend(loc="best", prop=font)
        plt.xlim([0, len(mean_history)])
        plt.grid()
        # Without tight_layout the saved image came out incomplete (cropped).
        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # Close exactly the figure created here, on success or failure, so
        # it never lingers in pyplot's global figure registry.
        plt.close(fig)


# ---- END
